Summarize escape across all sera alongside functional effects and ACE2 affinity

In [1]:
import altair as alt

import pandas as pd

import polyclonal.alphabets
from polyclonal.plot import color_gradient_hex

# Altair by default refuses to embed large datasets (MaxRowsError); the escape
# table alone has 5970 rows, so lift that limit. The return value is discarded.
_ = alt.data_transformers.disable_max_rows()

The next cell is tagged as parameters for papermill parameterization:

In [2]:
# Placeholder values; papermill overrides them via the injected "# Parameters" cell.
# inputs
sera = None  # dict mapping each serum name to its mutation-escape CSV
site_numbering_map_csv = None  # CSV with reference_site, sequential_site, region
func_effects_csv = None  # CSV of mutation functional effects
receptor_affinity_csv = None  # CSV of mutation effects on ACE2 affinity
# outputs
chart = None  # HTML file for the merged chart
csv_file = None  # CSV file for the summary data
In [3]:
# Parameters
# every serum's averaged escape CSV follows the same naming pattern
sera = {
    serum: f"results/antibody_escape/averages/{serum}_mut_effect.csv"
    for serum in [
        "sera_493C_highACE2",
        "sera_498C_highACE2",
        "sera_500C_highACE2",
        "sera_503C_highACE2",
        "sera_493C_mediumACE2",
        "sera_498C_mediumACE2",
        "sera_500C_mediumACE2",
        "sera_501C_mediumACE2",
        "sera_503C_mediumACE2",
        "sera_287C_mediumACE2",
        "sera_288C_mediumACE2",
        "sera_343C_mediumACE2",
        "sera_497C_mediumACE2",
        "sera_505C_mediumACE2",
    ]
}
site_numbering_map_csv = "data/site_numbering_map.csv"
func_effects_csv = "results/func_effects/averages/293T_high_ACE2_entry_func_effects.csv"
receptor_affinity_csv = "results/receptor_affinity/averages/monomeric_ACE2_mut_effect.csv"
chart = "results/summaries/escape_summary_nolegend.html"
csv_file = "results/summaries/escape_summary.csv"

Some configuration for the plots:

In [4]:
times_seen = 3  # only include mutations with times_seen >= this
frac_models = 1  # only include mutations in >= this fraction of models / selections
escape_stat = "escape_median"  # for each serum, use this escape value (mean or median)
init_site_escape_stat = "mean"  # default site escape stat to show
init_min_func_effect = -3  # default minimum functional effect to show
init_floor_escape_at_zero = True  # default on whether to floor escape at zero
init_min_receptor_effect = -1  # default minimum ACE2 affinity effect to show

# for heatmap colors: each heatmap shades negative values toward *_negative_color
# and positive values toward *_positive_color, with white at zero; the upper end of
# each color scale is at least *_max_at_least (see each heatmap for its lower end)
escape_negative_color = "#0072B2"  # french blue
escape_positive_color = "#E69F00"  # orange
escape_max_at_least = 1
escape_min_at_least = -1

func_positive_color = "#009E73"  # green
func_negative_color = "#CC79A7"  # wild orchid
func_max_at_least = 1
func_min_at_least = 0

receptor_positive_color = "#FF715B"  # pink
receptor_negative_color = "#F3C13A"  # yellow
receptor_max_at_least = 1
receptor_min_at_least = 0

Read the escape data for all sera, then read the functional effects and add site numbering to them:

In [5]:
# Stack every serum's mutation-escape table (tagged by serum), rename the chosen
# escape statistic to "escape", and keep only well-measured mutations.
escape_tidy = (
    pd.concat(
        [
            pd.read_csv(csv_path).assign(serum=serum_name)
            for serum_name, csv_path in sera.items()
        ]
    )
    .rename(columns={escape_stat: "escape"})
    .query("frac_models >= @frac_models and times_seen >= @times_seen")
    .loc[:, ["epitope", "serum", "site", "wildtype", "mutant", "escape"]]
)

assert escape_tidy["epitope"].nunique() == 1, "averaging only works for one epitope"

# wide table: one row per mutation, one column of escape values per serum
escape = (
    escape_tidy.pivot_table(
        index=["site", "wildtype", "mutant"],
        columns="serum",
        values="escape",
    )
    .reset_index()
    .assign(site_mutant=lambda df: df["site"].astype(str) + df["mutant"])
)

assert escape["site_mutant"].nunique() == len(escape)

# Site numbering: reference site (renamed "site" to match the mutation tables),
# sequential site, and region.
site_numbering_map = pd.read_csv(site_numbering_map_csv).rename(
    columns={"reference_site": "site"}
)[["site", "sequential_site", "region"]]

# Functional effects of mutations, filtered like the escape data: enough
# times_seen and measured in a large enough fraction of selections.
func_effects = (
    pd.read_csv(func_effects_csv)
    .rename(columns={"effect": "functional effect"})
    .query("times_seen >= @times_seen")
    .assign(
        frac_selections=lambda df: df["n_selections"].div(df["n_selections"].max())
    )
    .query("frac_selections >= @frac_models")
    .loc[:, ["site", "wildtype", "mutant", "functional effect"]]
)
# add wildtype functional effects of zero, then attach site labels and numbering
func_effects = (
    pd.concat(
        [
            func_effects,
            func_effects.loc[:, ["site", "wildtype"]]
            .drop_duplicates()
            .assign(mutant=lambda df: df["wildtype"], **{"functional effect": 0}),
        ],
        ignore_index=True,
    )
    .assign(site_mutant=lambda df: df["site"].astype(str) + df["mutant"])
    .merge(site_numbering_map, on="site", validate="many_to_one")
)

assert func_effects["site_mutant"].nunique() == len(func_effects)

escape
Out[5]:
serum site wildtype mutant sera_287C_mediumACE2 sera_288C_mediumACE2 sera_343C_mediumACE2 sera_493C_highACE2 sera_493C_mediumACE2 sera_497C_mediumACE2 sera_498C_highACE2 sera_498C_mediumACE2 sera_500C_highACE2 sera_500C_mediumACE2 sera_501C_mediumACE2 sera_503C_highACE2 sera_503C_mediumACE2 sera_505C_mediumACE2 site_mutant
0 2 F C -0.090660 -0.046130 0.168700 -0.005757 0.070280 0.206800 0.011310 -0.024830 0.029200 0.011860 -0.092530 0.067350 0.042310 -0.134400 2C
1 2 F L -0.039860 -0.013830 -0.068340 0.020580 -0.019610 -0.093650 -0.120000 0.190600 -0.031700 0.013100 0.058580 0.031060 0.196100 -0.035530 2L
2 2 F S 0.433800 -0.182900 -0.447500 -0.020950 0.071660 0.227500 0.074010 -0.112900 0.007172 0.013640 0.129400 0.022210 0.119300 0.064940 2S
3 3 V A -0.013810 0.000359 0.258300 -0.012670 -0.084880 0.330600 -0.006777 0.070390 0.002656 -0.088690 -0.280600 -0.033720 -0.030210 0.078700 3A
4 3 V F 0.076880 -0.222400 -0.030260 0.031730 0.293300 0.323700 -0.070420 0.167500 0.051390 0.009024 0.056410 -0.093320 0.331800 0.121200 3F
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
5965 1252 S F -0.009048 -0.011390 -0.001699 -0.021650 0.000229 0.001077 -0.001636 -0.006037 0.004098 -0.007209 -0.015120 -0.007979 -0.006033 -0.006353 1252F
5966 1252 S P -0.005321 -0.000230 -0.002522 0.002896 -0.008243 0.003474 0.000020 0.001526 -0.008855 -0.005124 -0.009008 -0.011420 -0.004026 -0.010780 1252P
5967 1252 S T -0.003340 0.005943 0.004079 0.004959 0.005474 -0.007287 -0.006334 0.007377 -0.001574 0.001835 0.009529 -0.007007 -0.001255 -0.004719 1252T
5968 1252 S Y -0.005370 -0.001432 -0.004039 0.007259 0.000128 -0.019020 0.008188 -0.009239 -0.023370 -0.010060 0.005224 -0.011650 -0.011420 -0.006647 1252Y
5969 1253 * R -0.001390 0.002490 -0.001332 -0.001703 0.005983 -0.008570 0.011640 -0.002179 -0.015250 -0.000918 0.002691 -0.003436 -0.000316 0.001364 1253R

5970 rows × 18 columns

In [6]:
# functional effects, including the zero-effect wildtype rows and site numbering
func_effects
Out[6]:
site wildtype mutant functional effect site_mutant sequential_site region
0 1 M I -6.323000 1I 1 other
1 1 M M 0.000000 1M 1 other
2 2 F C 0.101000 2C 2 other
3 2 F L 0.094320 2L 2 other
4 2 F S 0.058440 2S 2 other
... ... ... ... ... ... ... ...
8520 1252 S T 0.003655 1252T 1248 other
8521 1252 S Y 0.058880 1252Y 1248 other
8522 1252 S S 0.000000 1252S 1248 other
8523 1253 * R 0.092380 1253R 1249 other
8524 1253 * * 0.000000 1253* 1249 other

8525 rows × 7 columns

In [7]:
# Read mutation effects on ACE2 receptor affinity, keeping only mutations that pass
# the same times_seen and fraction-of-models thresholds used for escape.
# NOTE(review): the "effect" rename may be a no-op if the CSV already names this
# column "affinity_median" — confirm against the CSV.
receptor_affinity = (
    pd.read_csv(receptor_affinity_csv)
    .rename(columns={"effect": "affinity_median"})
    .query("times_seen >= @times_seen")
    # fraction of models measuring each mutation, relative to the most-measured one
    .assign(frac_selections=lambda x: x["n_models"] / x["n_models"].max())
    .query("frac_selections >= @frac_models")
    [["site", "wildtype", "mutant", "affinity_median"]]
)

# add wildtype receptor affinity effects of zero, then attach site labels and numbering
receptor_affinity = (
    pd.concat(
        [
            receptor_affinity,
            (
                receptor_affinity
                [["site", "wildtype"]]
                .drop_duplicates()
                .assign(
                    mutant=lambda x: x["wildtype"],
                    **{"affinity_median": 0},
                )
            )
        ],
        ignore_index=True,
    )
    .assign(site_mutant=lambda x: x["site"].astype(str) + x["mutant"])
    .merge(site_numbering_map, on="site", validate="many_to_one")
)

assert receptor_affinity["site_mutant"].nunique() == len(receptor_affinity)
receptor_affinity
Out[7]:
site wildtype mutant affinity_median site_mutant sequential_site region
0 2 F C 0.021510 2C 2 other
1 2 F L -0.269800 2L 2 other
2 2 F S -0.056420 2S 2 other
3 2 F F 0.000000 2F 2 other
4 3 V A -0.049770 3A 3 other
... ... ... ... ... ... ... ...
7272 1252 S T -0.114200 1252T 1248 other
7273 1252 S Y 0.000546 1252Y 1248 other
7274 1252 S S 0.000000 1252S 1248 other
7275 1253 * R 0.041330 1253R 1249 other
7276 1253 * * 0.000000 1253* 1249 other

7277 rows × 7 columns

Now make a site summary escape plot for all the sera:

In [8]:
# Interactive controls shared by the site-level chart and the heatmaps.

# radio buttons: whether escape values are floored at zero
floor_escape_at_zero = alt.param(
    name="floor_escape_at_zero",
    value=init_floor_escape_at_zero,
    bind=alt.binding_radio(name="floor escape at zero", options=[True, False]),
)

# dropdown: which per-site summary of mutation escape to plot
site_stats = ["mean", "sum", "max", "min"]
site_escape_selection = alt.selection_point(
    fields=["site escape statistic"],
    value=init_site_escape_stat,
    bind=alt.binding_select(name="site escape statistic", options=site_stats),
)

# hovering over a site selects it
site_selection = alt.selection_point(on="mouseover", fields=["site"], empty=False)

# slider: mutations with a functional effect below this value are filtered
func_effects_slider = alt.param(
    name="func_effects_slider",
    value=init_min_func_effect,
    bind=alt.binding_range(
        name="minimum mutation functional effect",
        min=func_effects["functional effect"].min(),
        max=0,
    ),
)

# slider: mutations with an ACE2 affinity effect below this value are filtered
receptor_affinity_slider = alt.param(
    name="receptor_affinity_slider",
    value=init_min_receptor_effect,
    bind=alt.binding_range(
        name="minimum ACE2 affinity",
        min=receptor_affinity["affinity_median"].min(),
        max=0,
    ),
)

# interval brush along the site (x) axis, drawn as an unfilled black box
site_brush = alt.selection_interval(
    encodings=["x"],
    empty=True,
    mark=alt.BrushConfig(stroke="black", strokeWidth=2, fillOpacity=0),
)

site_escape_width = 1200  # width of site escape chart

# Base encoding shared by the line and point layers of the site escape plot. The
# "escape", "serum", "sequential_site", and "region" fields it uses are not columns
# of `escape` itself; they are produced by transforms chained on afterwards.
site_escape_base = alt.Chart(escape).encode(
    y=alt.Y(
        "escape:Q",
        axis=alt.Axis(grid=False),
        scale=alt.Scale(nice=False, padding=10),
    ),
    tooltip=[
        "site",
        alt.Tooltip("escape:Q", format=".2f"),
        "wildtype",
        "sequential_site:Q",
        "serum:N",
        "region:N",
    ],
)

# thin lines connecting the sites
site_escape_lines = site_escape_base.mark_line(size=0.75, opacity=1)

# small filled circles per site, given a thick red outline when the site is selected
site_escape_points = site_escape_base.mark_circle(
    filled=True, opacity=1, stroke="red", size=20
).encode(
    strokeWidth=alt.condition(site_selection, alt.value(3), alt.value(0)),
)

# Transform pipeline shared by the per-serum and mean site-escape charts: turns the
# wide `escape` table (one column per serum) into one row per site, serum, and
# selected site escape statistic, honoring the interactive widgets. The order of
# the transforms matters.
site_escape_lines_and_points = (
    (site_escape_lines + site_escape_points)
    # wide -> long: one row per mutation and serum, raw value in "escape_orig"
    .transform_fold(fold=list(sera), as_=["serum", "escape_orig"])
    # floor escape at zero if selected
    # NOTE(review): escape_orig is null for mutations missing for a serum; in Vega
    # expressions max(null, 0) evaluates to 0, so when flooring, missing values appear
    # to count as zero escape in the site statistics, whereas without flooring they
    # stay null and are skipped by the aggregates — confirm this is intended.
    .transform_calculate(
        escape=alt.expr.if_(
            floor_escape_at_zero,
            alt.expr.max(alt.datum["escape_orig"], 0),
            alt.datum["escape_orig"],
        )
    )
    # look up functional effects and ACE2 affinities by mutation, then filter on
    # both sliders
    # NOTE(review): mutations absent from func_effects / receptor_affinity get null
    # from these lookups, and null >= (a slider value <= 0) appears to pass the filter
    # (null coerces to 0), unlike the heatmaps, which convert nulls to NaN and exclude
    # them — confirm intended.
    .transform_lookup(
        lookup="site_mutant",
        from_=alt.LookupData(
            func_effects,
            key="site_mutant",
            fields=["functional effect"],
        ),
    )
    .transform_lookup(
        lookup="site_mutant",
        from_=alt.LookupData(
            receptor_affinity,
            key="site_mutant",
            fields=["affinity_median"],
        ),
    )
    .transform_filter(alt.datum["functional effect"] >= func_effects_slider)
    .transform_filter(alt.datum["affinity_median"] >= receptor_affinity_slider)
    # compute site statistics from mutation statistics
    # (one column per statistic in site_stats, per site and serum)
    .transform_aggregate(
        **{stat: f"{stat}(escape)" for stat in site_stats},
        groupby=["site", "serum", "wildtype"],
    )
    # filter on site statistic of interest: fold the statistics into rows, then keep
    # only the one chosen in the dropdown (its value becomes "escape")
    .transform_fold(fold=site_stats, as_=["site escape statistic", "escape"])
    .transform_filter(site_escape_selection)
    # get sequential sites and regions
    .transform_lookup(
        lookup="site",
        from_=alt.LookupData(
            site_numbering_map,
            key="site",
            fields=["sequential_site", "region"],
        ),
    )
)

# One faceted row per serum showing the selected per-site escape statistic; sites
# outside the brushed x-range are dimmed.
site_escape = (
    site_escape_lines_and_points
    .encode(
        x=alt.X(
            "site:N",
            # SortField takes a bare field name (type shorthand such as ":Q" is only
            # parsed for encoding channels), so sort on "sequential_site" exactly as
            # the heatmap x-axes do
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(labelOverlap=True, grid=False),
        ),
        opacity=alt.condition(site_brush, alt.value(1), alt.value(0.4)),
        color=alt.value("gray"),
    )
    .properties(height=90, width=site_escape_width)
    .facet(
        facet=alt.Facet(
            "serum:N",
            title="individual sera",
            header=alt.Header(
                labelOrient="right",
                labelFontSize=10,
                labelPadding=3,
                titleOrient="right",
                titlePadding=3,
            ),
        ),
        columns=1,
        spacing=0,
    )
)

# Mean across sera of the selected site escape statistic, drawn above the
# per-serum facets; sites outside the brushed x-range are dimmed.
site_mean_escape = (
    site_escape_lines_and_points
    # average missing values as zero
    .transform_calculate(
        escape=alt.expr.if_(
            alt.expr.isValid(alt.datum["escape"]),
            alt.datum["escape"],
            0,
        ),
    )
    # take mean over sera
    .transform_aggregate(
        escape="mean(escape)",
        groupby=["site", "wildtype", "sequential_site", "region"],
    )
    .transform_calculate(serum="'mean of all sera'")
    .encode(
        x=alt.X(
            "site:N",
            # SortField takes a bare field name (no ":Q" type suffix), matching the
            # field used to sort the heatmap x-axes
            sort=alt.SortField("sequential_site"),
            axis=None,
        ),
        opacity=alt.condition(site_brush, alt.value(1), alt.value(0.4)),
        color=alt.value("black"),
    )
    .properties(
        height=70,
        width=site_escape_width,
        title=alt.TitleParams(
            "mean of sera", fontSize=11, fontWeight="bold", orient="right",
        ),
    )
)

# Thin colored strip above the site plots marking the region of each site.
region_bar = (
    alt.Chart(site_numbering_map)
    .encode(
        x=alt.X(
            "site:N",
            # SortField takes a bare field name (no ":Q" type suffix), matching the
            # field used to sort the heatmap x-axes
            sort=alt.SortField("sequential_site"),
            axis=None,
        ),
        color=alt.Color(
            "region",
            # regions colored in their order of first appearance in the map
            scale=alt.Scale(domain=site_numbering_map["region"].unique()),
        ),
        tooltip=["site", "region", "sequential_site"],
    )
    .mark_rect()
    .properties(width=site_escape_width, height=9)
)

# Region strip on top of the mean and per-serum escape charts. The widget params
# are attached to the escape charts; the site brush is attached to the whole stack.
site_chart = alt.vconcat(
    region_bar,
    alt.vconcat(
        site_mean_escape,
        site_escape,
        spacing=3,
    ).add_params(
        site_escape_selection,
        site_selection,
        func_effects_slider,
        receptor_affinity_slider,
        floor_escape_at_zero,
    ),
    spacing=0,
).add_params(site_brush)

site_chart
Out[8]:

Now prepare to plot the heatmaps. First, create a data frame that has the functional effects, the ACE2 affinity effects, and the average escape across sera (averaging mutations missing for a serum as zero for that serum):

In [9]:
# One row per mutation (plus an explicit wildtype row per site) with:
#  - "escape": mean escape across sera, counting a mutation missing for a serum as zero
#  - "functional effect" and "affinity_median" from the other two tables
#  - site numbering ("sequential_site", "region")
# The merges use explicit mutation keys. Merging on all shared columns would also
# match on "site_mutant" / "sequential_site" / "region", which are NaN for mutations
# absent from func_effects, so such a mutation present in receptor_affinity would
# end up split across two rows.
heatmap_data = (
    pd.concat(
        [
            escape,
            # add wildtype with zero escape
            (
                escape
                [["site", "wildtype"]]
                .drop_duplicates()
                .assign(mutant=lambda x: x["wildtype"])
            ),
        ],
        ignore_index=True,
    )
    # values missing for a serum (and the new wildtype rows) count as zero escape
    .fillna(0)
    .assign(escape=lambda x: x[list(sera)].mean(axis=1))
    [["site", "wildtype", "mutant", "escape"]]
    .merge(
        func_effects[["site", "wildtype", "mutant", "functional effect"]],
        on=["site", "wildtype", "mutant"],
        how="outer",
        validate="one_to_one",
    )
    .merge(
        receptor_affinity[["site", "wildtype", "mutant", "affinity_median"]],
        on=["site", "wildtype", "mutant"],
        how="outer",
        validate="one_to_one",
    )
    .merge(site_numbering_map, on="site", validate="many_to_one")
    # wildtype rows have zero escape, including those only present in the other tables
    .assign(
        escape=lambda x: x["escape"].where(
            x["wildtype"] != x["mutant"],
            0,
        ),
    )
)

Write these data to a CSV, adding the mean escape of each site:

In [10]:
print(f"Writing summary data to {csv_file}")

# Attach each site's mean escape over its non-wildtype mutations to every row of
# that site, then write the table.
heatmap_data.merge(
    heatmap_data.loc[heatmap_data["wildtype"] != heatmap_data["mutant"]]
    .groupby("site", as_index=False)["escape"]
    .mean()
    .rename(columns={"escape": "mean_site_escape"}),
    how="outer",
    validate="many_to_one",
).to_csv(csv_file, index=False, float_format="%.4g")
Writing summary data to results/summaries/escape_summary.csv

Make heatmaps of escape, functional effects, and ACE2 affinity:

In [11]:
cell_size = 9  # heatmap cell size

# amino-acid order for the heatmap rows
alphabet = polyclonal.alphabets.biochem_order_aas(func_effects["mutant"].unique())

heatmap_base = (
    alt.Chart(heatmap_data)
    # convert null values to NaN so they show as NaN in tooltips rather than as 0.0
    .transform_calculate(
        **{
            col: alt.expr.if_(
                alt.expr.isFinite(alt.datum[col]),
                alt.datum[col],
                alt.expr.NaN,
            )
            for col in ["escape", "functional effect", "affinity_median"]
        }
    )
    # Floor escape at zero if selected. This must come after the NaN conversion: on
    # the raw nulls, max(null, 0) evaluates to 0, which would draw mutations with no
    # escape measurement as zero-escape cells when flooring. max(NaN, 0) stays NaN,
    # so missing escape is treated the same whether or not flooring is on.
    .transform_calculate(
        escape_floored=alt.expr.if_(
            floor_escape_at_zero,
            alt.expr.max(alt.datum["escape"], 0),
            alt.datum["escape"],
        )
    )
    .encode(
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            axis=alt.Axis(labelFontSize=9, ticks=False),
        ),
        y=alt.Y(
            "mutant:N",
            title="amino acid",
            sort=alphabet,
            axis=alt.Axis(labelFontSize=9, ticks=False),
        ),
    )
    .properties(width=alt.Step(cell_size), height=alt.Step(cell_size))
    .add_params(func_effects_slider, receptor_affinity_slider, floor_escape_at_zero)
)

# lowercase "x" drawn on each site's wildtype amino acid
heatmap_wildtype = heatmap_base.transform_filter(
    alt.datum["wildtype"] == alt.datum["mutant"]
).mark_text(text="x", color="black")

# gray background for missing values: imputing a dummy field adds a row for every
# amino acid in `alphabet` at each site, so every cell gets a light-gray rectangle
heatmap_bg = heatmap_base.transform_impute(
    impute="_stat_dummy",
    key="mutant",
    keyvals=alphabet,
    groupby=["site"],
    value=None,
).mark_rect(color="#E0E0E0")

# fields shown on hover in every heatmap; the numeric ones use two decimals
tooltips = [
    "site",
    "mutant",
    *(
        alt.Tooltip(field, format=".2f")
        for field in ["escape", "functional effect", "affinity_median"]
    ),
    "wildtype",
    "sequential_site",
    "region",
]

# shared color-legend layout: placed at the left with a short, outlined gradient
legend = alt.Legend(
    orient="left",
    titleOrient="left",
    gradientLength=100,
    gradientThickness=10,
    gradientStrokeColor="black",
    gradientStrokeWidth=0.5,
)

# heatmap for escape: mean escape across sera for mutations passing both the
# functional-effect and ACE2-affinity sliders (wildtype cells are always kept)
escape_heatmap = (
    heatmap_base
    .transform_filter(
        (alt.datum["functional effect"] >= func_effects_slider)
        | (alt.datum["wildtype"] == alt.datum["mutant"])
    )
    .transform_filter(
        (alt.datum["affinity_median"] >= receptor_affinity_slider)
        | (alt.datum["wildtype"] == alt.datum["mutant"])
    )
    .encode(
        # turn off x-labels for this heatmap since it is stacked
        x=alt.X(
            "site:N",
            sort=alt.SortField("sequential_site"),
            title=None,
            axis=alt.Axis(ticks=False, labels=False),
        ),
        color=alt.Color(
            "escape_floored:Q",
            title="escape",
            legend=legend,
            scale=alt.Scale(
                zero=True,
                nice=False,
                type="linear",
                domainMid=0,
                domainMax=max(escape_max_at_least, heatmap_data["escape"].max()),
                # when escape is not floored, extend the lower end to the most negative
                # mean escape (mirroring how domainMax extends to the largest value),
                # so no displayed value falls below the color scale
                domainMin=alt.ExprRef(
                    "if(floor_escape_at_zero, 0, "
                    f"{float(min(escape_min_at_least, heatmap_data['escape'].min()))})"
                ),
                range=(
                    color_gradient_hex(escape_negative_color, "white", n=20)
                    + color_gradient_hex("white", escape_positive_color, n=20)[1:]
                ),
            ),
        ),
        tooltip=tooltips,
    )
    .mark_rect(stroke="black")
)

# Silver cells: non-wildtype mutations left out of the escape heatmap because their
# functional effect is below the functional-effect slider.
escape_func_filtered_heatmap = (
    heatmap_base
    .transform_filter(
        (alt.datum["wildtype"] != alt.datum["mutant"])
        & (alt.datum["functional effect"] < func_effects_slider)
    )
    .transform_calculate(filtered="''")
    .mark_rect(stroke="black")
    .encode(
        color=alt.Color(
            "filtered:N",
            title=["functionally", "deleterious"],
            legend=None,
            scale=alt.Scale(range=["silver"]),
        ),
        tooltip=tooltips,
    )
)

# Functional effect of every mutation (no slider filtering). The scale is clamped;
# its lower end is min(functional-effect slider, func_min_at_least).
func_heatmap = (
    heatmap_base
    .mark_rect(stroke="black")
    .encode(
        tooltip=tooltips,
        color=alt.Color(
            "functional effect",
            legend=legend,
            scale=alt.Scale(
                type="linear",
                zero=True,
                nice=False,
                clamp=True,
                domainMid=0,
                domainMin=alt.ExprRef(f"min(func_effects_slider, {func_min_at_least})"),
                domainMax=max(
                    func_max_at_least, heatmap_data["functional effect"].max()
                ),
                range=(
                    color_gradient_hex(func_negative_color, "white", n=20)
                    + color_gradient_hex("white", func_positive_color, n=20)[1:]
                ),
            ),
        ),
    )
)

# Silver cells: non-wildtype mutations left out of the escape heatmap because their
# ACE2 affinity effect is below the affinity slider.
escape_receptor_filtered_heatmap = (
    heatmap_base
    .transform_filter(
        (alt.datum["wildtype"] != alt.datum["mutant"])
        & (alt.datum["affinity_median"] < receptor_affinity_slider)
    )
    .transform_calculate(filtered="''")
    .mark_rect(stroke="black")
    .encode(
        color=alt.Color(
            "filtered:N",
            title=["affinity", "deleterious"],
            legend=None,
            scale=alt.Scale(range=["silver"]),
        ),
        tooltip=tooltips,
    )
)

# ACE2 affinity effect of every mutation (no slider filtering). The scale is clamped;
# its lower end is min(affinity slider, receptor_min_at_least).
receptor_heatmap = (
    heatmap_base
    .mark_rect(stroke="black")
    .encode(
        tooltip=tooltips,
        color=alt.Color(
            "affinity_median",
            legend=legend,
            scale=alt.Scale(
                type="linear",
                zero=True,
                nice=False,
                clamp=True,
                domainMid=0,
                domainMin=alt.ExprRef(
                    f"min(receptor_affinity_slider, {receptor_min_at_least})"
                ),
                domainMax=max(
                    receptor_max_at_least, heatmap_data["affinity_median"].max()
                ),
                range=(
                    color_gradient_hex(receptor_negative_color, "white", n=20)
                    + color_gradient_hex("white", receptor_positive_color, n=20)[1:]
                ),
            ),
        ),
    )
)

# Stack the escape, functional-effect, and ACE2-affinity heatmaps. Each is layered
# over the gray missing-value background and topped by the wildtype "x" marks, and
# each keeps an independent color scale.
heatmap = alt.vconcat(
    (
        heatmap_bg
        + escape_heatmap
        + escape_func_filtered_heatmap
        + escape_receptor_filtered_heatmap
        + heatmap_wildtype
    ),
    heatmap_bg + func_heatmap + heatmap_wildtype,
    heatmap_bg + receptor_heatmap + heatmap_wildtype,
    spacing=1,
).resolve_scale(color="independent")

heatmap
Out[11]:

Make a merged chart with everything:

In [12]:
# Site-level chart above the heatmaps; the heatmaps only show sites inside the
# site brush. Legends go on the left.
merged_chart = (
    alt.vconcat(site_chart, heatmap.transform_filter(site_brush), spacing=5)
    .configure_legend(orient="left")
)

print(f"Saving to {chart}")
merged_chart.save(chart)

merged_chart
Saving to results/summaries/escape_summary_nolegend.html
Out[12]:
In [ ]: